#!/usr/bin/env python
import numpy as np
import tensorflow as tf
import time
import vae_gan_models as VMG
import pickle
#import vae_gan_models_shallow as VMG
#import vae_gan_models_deep as VMG
#import vae_gan_models_full_connection as VMG

import data_read as d
# Checkpoint path restored by VAE_generator via tf.train.Saver.restore.
# Previously used checkpoint, kept for reference:
#MODEL_DIRECTORY = "store_final/model.ckpt_6"

MODEL_DIRECTORY = "stores_sense/model.ckpt_g6-3"


def VAE_generator(z, t, model_directory=None):
    """Restore the VAE-GAN checkpoint and score its reconstruction of ``t``.

    A fresh TF1 graph is built on every call (``tf.reset_default_graph``).
    The ``VMG`` encoder/generator are instantiated and the weights are
    restored from ``model_directory``. All metrics are then fetched in a
    single ``Session.run``, so every returned value describes the same
    forward pass.

    Args:
        z: encoder input, fed to a float32 placeholder of shape
            ``[None, 800, 800, 10, 1]``.
        t: target volume, fed to a float32 placeholder of shape
            ``[None, 200, 200, 10, 1]``.
        model_directory: checkpoint path passed to ``Saver.restore``.
            ``None`` (default) uses the module-level ``MODEL_DIRECTORY``.

    Returns:
        Tuple ``(m90, m75, m50, m25, mine, max95, mean, maxerror, PNSR, SSIM)``.
        The first eight are statistics of the element-wise squared error
        restricted to the first 4 of the 10 channels:

        * ``m90, m75, m50, m25``: 90th/75th/50th/25th percentiles,
        * ``mine``: minimum,
        * ``max95``: 95th percentile,
        * ``mean``: mean squared error (MSE),
        * ``maxerror``: the **99th percentile**, not the true maximum.
          This is kept as-is for existing callers.

        ``PNSR`` is ``-10 * log10(MSE)`` over the same 4 channels, i.e. PSNR
        with a peak value of 1. ``SSIM`` is a single-window (global)
        SSIM-style statistic computed over *all* channels.
    """
    if model_directory is None:
        model_directory = MODEL_DIRECTORY

    batch_size = 1
    input_len = 800
    output_len = 200
    n_scored = 4  # only the first 4 channels enter the error statistics

    tf.reset_default_graph()
    x_vector = tf.placeholder(shape=[None, input_len, input_len, 10, 1], dtype=tf.float32)
    T_vector = tf.placeholder(shape=[None, output_len, output_len, 10, 1], dtype=tf.float32)

    # The encoder returns (z_latent, mean, stddev); only the latent is used here.
    z_latent, _, _ = VMG.encoder(x_vector, phase_train=True, reuse=False, training=False)

    # Reconstruct the output volume from the latent code.
    out_image = VMG.generater(z_latent, batch_size, phase_train=True, reuse=False, training=False)

    # Element-wise squared error on the scored channels, built once and reused.
    sq_err = tf.square(T_vector[:, :, :, :n_scored, :] - out_image[:, :, :, :n_scored, :])
    mse = tf.reduce_mean(sq_err)
    min_err = tf.reduce_min(sq_err)
    percentile = tf.contrib.distributions.percentile
    p99 = percentile(sq_err, 99)
    p95 = percentile(sq_err, 95)
    p90 = percentile(sq_err, 90)
    p75 = percentile(sq_err, 75)
    p50 = percentile(sq_err, 50)
    p25 = percentile(sq_err, 25)

    # PSNR assuming a peak signal value of 1: -10 * log10(MSE).
    pnsr = -10.0 * tf.log(mse) / tf.log(10.)

    # Global SSIM-style index over the whole volume (all 10 channels), using
    # small 1e-6 stabilizers instead of the usual C1/C2 constants.
    mean_y_ = tf.reduce_mean(T_vector)
    mean_y = tf.reduce_mean(out_image)
    var_y_ = tf.reduce_mean((T_vector - mean_y_) ** 2)
    var_y = tf.reduce_mean((out_image - mean_y) ** 2)
    covy_y = tf.reduce_mean((out_image - mean_y) * (T_vector - mean_y_))
    SSIM_up = (2 * mean_y_ * mean_y + 1e-6) * (2 * covy_y + 1e-6)
    SSIM_down = (mean_y_ ** 2 + mean_y ** 2 + 1e-6) * (var_y_ + var_y + 1e-6)
    ssim = SSIM_up / SSIM_down

    # Order matches the public return tuple. p99 sits in the "maxerror" slot.
    fetches = [p90, p75, p50, p25, min_err, p95, mse, p99, pnsr, ssim]

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # The context manager closes the session even if restore/run raises.
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, model_directory)
        # One run for all metrics. NOTE(review): if VMG.encoder samples the
        # latent stochastically, separate runs would score different samples.
        results = sess.run(fetches, feed_dict={x_vector: z, T_vector: t})

    return tuple(results)

#

if __name__ == '__main__':

    start_time = time.time()

    # Sequential validation: score the model on n_windows overlapping crops
    # slid along the validation strip. To evaluate a single crop, run the
    # loop body for one value of i.
    #
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open("sequential_validation_in_data", "rb") as fin:
        in_mrvbf = pickle.load(fin)

    with open("sequential_validation_out_data", "rb") as fin:
        out_ec = pickle.load(fin)

    n_windows = 20

    # Per-window results. These are kept at module level so they remain
    # inspectable after an interactive run.
    aPNSR = np.zeros(n_windows)
    aSSIM = np.zeros(n_windows)

    mine = np.zeros(n_windows)
    maxe = np.zeros(n_windows)   # VAE_generator's "maxerror" slot (99th percentile)
    meane = np.zeros(n_windows)
    m90 = np.zeros(n_windows)
    m75 = np.zeros(n_windows)
    m50 = np.zeros(n_windows)
    m25 = np.zeros(n_windows)
    m95 = np.zeros(n_windows)

    for i in range(n_windows):
        # Input crops are 800 wide with stride 40; target crops are 200 wide
        # with stride 10 (both a factor of 4 apart).
        start_in = i * 40
        end_in = start_in + 800
        start_out = i * 10
        end_out = start_out + 200
        in_ = in_mrvbf[:, start_in:end_in]
        out_ = out_ec[:, start_out:end_out, :].reshape(1, 200, 200, 10, 1)

        in_data = d.data_combine(in_)

        (m90[i], m75[i], m50[i], m25[i], mine[i], m95[i],
         meane[i], maxe[i], aPNSR[i], aSSIM[i]) = VAE_generator(in_data, out_)

        print(i)

    end_time = time.time()

    print("This sampling run took %5.4f seconds." % (end_time - start_time))
    
#





    


